{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GPU Light Gradient boosting trained on timestamp text data-set\n",
    "\n",
    "1. Same emotion dataset from [NLP-dataset](https://github.com/huseinzol05/NLP-Dataset)\n",
    "2. Same split: 80% training, 20% testing; results may vary depending on randomness\n",
    "3. Same regex substitution '[^\\\"\\'A-Za-z0-9 ]+'\n",
    "\n",
    "## Example\n",
    "\n",
    "Each word is encoded by its position in the sorted dictionary.\n",
    "\n",
    "text: 'module into which all the refactored classes', matrix: [167, 143, 12, 3, 4, 90]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import pickle\n",
    "import re\n",
    "import time\n",
    "\n",
    "import lightgbm as lgb\n",
    "import numpy as np\n",
    "import sklearn.datasets\n",
    "from sklearn import metrics\n",
    "\n",
    "# sklearn.cross_validation is deprecated since 0.18 and scheduled for removal\n",
    "# in 0.20; prefer model_selection and fall back only on very old installs.\n",
    "try:\n",
    "    from sklearn.model_selection import train_test_split\n",
    "except ImportError:\n",
    "    from sklearn.cross_validation import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def clearstring(string):\n",
    "    \"\"\"Strip disallowed characters and collapse runs of spaces.\n",
    "\n",
    "    Everything except ASCII letters, digits, single/double quotes and spaces\n",
    "    is removed; the remaining space-separated tokens are then re-joined with\n",
    "    exactly one space between them.\n",
    "    \"\"\"\n",
    "    cleaned = re.sub('[^\\\"\\'A-Za-z0-9 ]+', '', string)\n",
    "    # splitting on a single space yields '' for repeated spaces; drop those\n",
    "    tokens = [piece.strip() for piece in cleaned.split(' ') if piece]\n",
    "    return ' '.join(tokens)\n",
    "\n",
    "# sklearn.datasets reads each document as a single element,\n",
    "# so every document is split into its individual lines here.\n",
    "def separate_dataset(trainset):\n",
    "    \"\"\"Flatten per-document data into per-line samples.\n",
    "\n",
    "    Each entry of `trainset.data` is split on newlines, empty lines are\n",
    "    dropped, and every remaining line is passed through `clearstring`.\n",
    "    The label `trainset.target[i]` is repeated once per kept line.\n",
    "\n",
    "    Returns a `(texts, labels)` pair of equally long lists.\n",
    "    \"\"\"\n",
    "    datastring, datatarget = [], []\n",
    "    for idx, document in enumerate(trainset.data):\n",
    "        lines = [clearstring(line) for line in document.split('\\n') if line]\n",
    "        datastring.extend(lines)\n",
    "        datatarget.extend(trainset.target[idx] for _ in lines)\n",
    "    return datastring, datatarget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Read the corpus from the 'data' directory, decoding files as UTF-8.\n",
    "trainset_data = sklearn.datasets.load_files(container_path = 'data', encoding = 'UTF-8')\n",
    "# Overwrite the per-document texts and labels with per-line samples.\n",
    "trainset_data.data, trainset_data.target = separate_dataset(trainset_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load the pre-built dictionary that is looked up per token when the feature\n",
    "# matrix is filled.\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only open a\n",
    "# dictionary_emotion.p that comes from a trusted source.\n",
    "with open('dictionary_emotion.p', 'rb') as fopen:\n",
    "    dict_emotion = pickle.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Token count (whitespace-separated) of every sample.\n",
    "len_sentences = np.array([len(sentence.split()) for sentence in trainset_data.data])\n",
    "# Feature width = mean sentence length rounded up, so longer sentences get truncated.\n",
    "maxlen = np.ceil(np.mean(len_sentences)).astype('int')\n",
    "# One zero-filled row per sample.\n",
    "data_X = np.zeros((len_sentences.size, maxlen))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Right-align each sentence in its row: the last kept token lands in the last\n",
    "# column and unused leading columns stay 0.\n",
    "for i in range(data_X.shape[0]):\n",
    "    # sentences longer than maxlen keep only their first maxlen tokens\n",
    "    tokens = trainset_data.data[i].split()[:maxlen]\n",
    "    for no, text in enumerate(reversed(tokens)):\n",
    "        try:\n",
    "            data_X[i, -1 - no] = dict_emotion[text]\n",
    "        except KeyError:\n",
    "            # token missing from the dictionary: leave its slot at 0.\n",
    "            # Only KeyError is skipped so that real failures (and Ctrl-C)\n",
    "            # are no longer silently swallowed.\n",
    "            continue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Hold out 20% of the samples for evaluation. No random_state is passed, so the\n",
    "# split (and every score computed from it) changes from run to run.\n",
    "train_X, test_X, train_Y, test_Y = train_test_split(data_X, trainset_data.target, test_size = 0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Parameters handed to lgb.train.\n",
    "params_lgb = {\n",
    "    'max_depth': 27, \n",
    "    'learning_rate': 0.03,\n",
    "    'verbose': 50, \n",
    "    # stop once the validation metric has not improved for this many rounds\n",
    "    'early_stopping_round': 200,\n",
    "    'metric': 'multi_logloss',\n",
    "    'objective': 'multiclass',\n",
    "    # one class per label name in the loaded dataset\n",
    "    'num_classes': len(trainset_data.target_names),\n",
    "    # NOTE(review): presumably requires a GPU-enabled LightGBM build; the two ids\n",
    "    # below select platform 0 / device 0 -- confirm on the target machine\n",
    "    'device': 'gpu',\n",
    "    'gpu_platform_id': 0,\n",
    "    'gpu_device_id': 0\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/lightgbm/engine.py:104: UserWarning: Found `early_stopping_round` in params. Will use it instead of argument\n",
      "  warnings.warn(\"Found `{}` in params. Will use it instead of argument\".format(alias))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training until validation scores don't improve for 200 rounds.\n",
      "[100]\ttraining's multi_logloss: 1.53707\tvalid_1's multi_logloss: 1.54177\n",
      "[200]\ttraining's multi_logloss: 1.50251\tvalid_1's multi_logloss: 1.51202\n",
      "[300]\ttraining's multi_logloss: 1.48103\tvalid_1's multi_logloss: 1.49636\n",
      "[400]\ttraining's multi_logloss: 1.46418\tvalid_1's multi_logloss: 1.48556\n",
      "[500]\ttraining's multi_logloss: 1.44901\tvalid_1's multi_logloss: 1.47635\n",
      "[600]\ttraining's multi_logloss: 1.43547\tvalid_1's multi_logloss: 1.46905\n",
      "[700]\ttraining's multi_logloss: 1.42315\tvalid_1's multi_logloss: 1.4629\n",
      "[800]\ttraining's multi_logloss: 1.41188\tvalid_1's multi_logloss: 1.45773\n",
      "[900]\ttraining's multi_logloss: 1.40093\tvalid_1's multi_logloss: 1.45285\n",
      "[1000]\ttraining's multi_logloss: 1.39076\tvalid_1's multi_logloss: 1.44862\n",
      "[1100]\ttraining's multi_logloss: 1.3808\tvalid_1's multi_logloss: 1.44438\n",
      "[1200]\ttraining's multi_logloss: 1.3712\tvalid_1's multi_logloss: 1.4405\n",
      "[1300]\ttraining's multi_logloss: 1.36209\tvalid_1's multi_logloss: 1.43713\n",
      "[1400]\ttraining's multi_logloss: 1.35336\tvalid_1's multi_logloss: 1.43382\n",
      "[1500]\ttraining's multi_logloss: 1.34472\tvalid_1's multi_logloss: 1.43059\n",
      "[1600]\ttraining's multi_logloss: 1.3365\tvalid_1's multi_logloss: 1.42776\n",
      "[1700]\ttraining's multi_logloss: 1.32831\tvalid_1's multi_logloss: 1.42494\n",
      "[1800]\ttraining's multi_logloss: 1.32025\tvalid_1's multi_logloss: 1.42216\n",
      "[1900]\ttraining's multi_logloss: 1.31241\tvalid_1's multi_logloss: 1.41958\n",
      "[2000]\ttraining's multi_logloss: 1.30465\tvalid_1's multi_logloss: 1.41694\n",
      "[2100]\ttraining's multi_logloss: 1.29715\tvalid_1's multi_logloss: 1.41464\n",
      "[2200]\ttraining's multi_logloss: 1.28958\tvalid_1's multi_logloss: 1.41204\n",
      "[2300]\ttraining's multi_logloss: 1.28222\tvalid_1's multi_logloss: 1.4097\n",
      "[2400]\ttraining's multi_logloss: 1.27517\tvalid_1's multi_logloss: 1.40769\n",
      "[2500]\ttraining's multi_logloss: 1.26826\tvalid_1's multi_logloss: 1.4059\n",
      "[2600]\ttraining's multi_logloss: 1.26151\tvalid_1's multi_logloss: 1.40413\n",
      "[2700]\ttraining's multi_logloss: 1.25492\tvalid_1's multi_logloss: 1.40257\n",
      "[2800]\ttraining's multi_logloss: 1.24822\tvalid_1's multi_logloss: 1.40084\n",
      "[2900]\ttraining's multi_logloss: 1.24163\tvalid_1's multi_logloss: 1.39923\n",
      "[3000]\ttraining's multi_logloss: 1.23505\tvalid_1's multi_logloss: 1.3975\n",
      "[3100]\ttraining's multi_logloss: 1.22865\tvalid_1's multi_logloss: 1.39592\n",
      "[3200]\ttraining's multi_logloss: 1.22248\tvalid_1's multi_logloss: 1.39454\n",
      "[3300]\ttraining's multi_logloss: 1.21628\tvalid_1's multi_logloss: 1.39304\n",
      "[3400]\ttraining's multi_logloss: 1.2101\tvalid_1's multi_logloss: 1.3916\n",
      "[3500]\ttraining's multi_logloss: 1.20415\tvalid_1's multi_logloss: 1.39032\n",
      "[3600]\ttraining's multi_logloss: 1.19815\tvalid_1's multi_logloss: 1.38911\n",
      "[3700]\ttraining's multi_logloss: 1.19212\tvalid_1's multi_logloss: 1.38771\n",
      "[3800]\ttraining's multi_logloss: 1.18649\tvalid_1's multi_logloss: 1.38669\n",
      "[3900]\ttraining's multi_logloss: 1.18075\tvalid_1's multi_logloss: 1.38558\n",
      "[4000]\ttraining's multi_logloss: 1.17521\tvalid_1's multi_logloss: 1.38461\n",
      "[4100]\ttraining's multi_logloss: 1.16962\tvalid_1's multi_logloss: 1.3836\n",
      "[4200]\ttraining's multi_logloss: 1.16409\tvalid_1's multi_logloss: 1.38266\n",
      "[4300]\ttraining's multi_logloss: 1.1586\tvalid_1's multi_logloss: 1.38164\n",
      "[4400]\ttraining's multi_logloss: 1.15335\tvalid_1's multi_logloss: 1.38089\n",
      "[4500]\ttraining's multi_logloss: 1.14807\tvalid_1's multi_logloss: 1.38003\n",
      "[4600]\ttraining's multi_logloss: 1.14272\tvalid_1's multi_logloss: 1.37912\n",
      "[4700]\ttraining's multi_logloss: 1.1375\tvalid_1's multi_logloss: 1.37816\n",
      "[4800]\ttraining's multi_logloss: 1.13255\tvalid_1's multi_logloss: 1.37747\n",
      "[4900]\ttraining's multi_logloss: 1.1275\tvalid_1's multi_logloss: 1.37667\n",
      "[5000]\ttraining's multi_logloss: 1.12246\tvalid_1's multi_logloss: 1.37592\n",
      "[5100]\ttraining's multi_logloss: 1.11751\tvalid_1's multi_logloss: 1.37518\n",
      "[5200]\ttraining's multi_logloss: 1.1126\tvalid_1's multi_logloss: 1.37447\n",
      "[5300]\ttraining's multi_logloss: 1.10793\tvalid_1's multi_logloss: 1.37393\n",
      "[5400]\ttraining's multi_logloss: 1.10301\tvalid_1's multi_logloss: 1.3732\n",
      "[5500]\ttraining's multi_logloss: 1.09824\tvalid_1's multi_logloss: 1.37254\n",
      "[5600]\ttraining's multi_logloss: 1.09352\tvalid_1's multi_logloss: 1.37184\n",
      "[5700]\ttraining's multi_logloss: 1.0888\tvalid_1's multi_logloss: 1.37114\n",
      "[5800]\ttraining's multi_logloss: 1.08416\tvalid_1's multi_logloss: 1.37057\n",
      "[5900]\ttraining's multi_logloss: 1.07958\tvalid_1's multi_logloss: 1.37002\n",
      "[6000]\ttraining's multi_logloss: 1.07519\tvalid_1's multi_logloss: 1.36958\n",
      "[6100]\ttraining's multi_logloss: 1.07072\tvalid_1's multi_logloss: 1.36912\n",
      "[6200]\ttraining's multi_logloss: 1.06616\tvalid_1's multi_logloss: 1.36858\n",
      "[6300]\ttraining's multi_logloss: 1.06168\tvalid_1's multi_logloss: 1.36808\n",
      "[6400]\ttraining's multi_logloss: 1.05728\tvalid_1's multi_logloss: 1.36756\n",
      "[6500]\ttraining's multi_logloss: 1.05301\tvalid_1's multi_logloss: 1.36718\n",
      "[6600]\ttraining's multi_logloss: 1.04858\tvalid_1's multi_logloss: 1.36664\n",
      "[6700]\ttraining's multi_logloss: 1.04434\tvalid_1's multi_logloss: 1.36629\n",
      "[6800]\ttraining's multi_logloss: 1.04014\tvalid_1's multi_logloss: 1.36588\n",
      "[6900]\ttraining's multi_logloss: 1.0358\tvalid_1's multi_logloss: 1.36533\n",
      "[7000]\ttraining's multi_logloss: 1.03171\tvalid_1's multi_logloss: 1.36501\n",
      "[7100]\ttraining's multi_logloss: 1.02758\tvalid_1's multi_logloss: 1.36464\n",
      "[7200]\ttraining's multi_logloss: 1.02343\tvalid_1's multi_logloss: 1.36431\n",
      "[7300]\ttraining's multi_logloss: 1.01934\tvalid_1's multi_logloss: 1.36392\n",
      "[7400]\ttraining's multi_logloss: 1.01527\tvalid_1's multi_logloss: 1.36362\n",
      "[7500]\ttraining's multi_logloss: 1.01121\tvalid_1's multi_logloss: 1.36322\n",
      "[7600]\ttraining's multi_logloss: 1.00723\tvalid_1's multi_logloss: 1.36294\n",
      "[7700]\ttraining's multi_logloss: 1.00326\tvalid_1's multi_logloss: 1.36256\n",
      "[7800]\ttraining's multi_logloss: 0.999308\tvalid_1's multi_logloss: 1.36218\n",
      "[7900]\ttraining's multi_logloss: 0.995404\tvalid_1's multi_logloss: 1.36188\n",
      "[8000]\ttraining's multi_logloss: 0.991669\tvalid_1's multi_logloss: 1.36168\n",
      "[8100]\ttraining's multi_logloss: 0.987947\tvalid_1's multi_logloss: 1.36142\n",
      "[8200]\ttraining's multi_logloss: 0.984093\tvalid_1's multi_logloss: 1.36116\n",
      "[8300]\ttraining's multi_logloss: 0.980374\tvalid_1's multi_logloss: 1.36089\n",
      "[8400]\ttraining's multi_logloss: 0.97663\tvalid_1's multi_logloss: 1.36058\n",
      "[8500]\ttraining's multi_logloss: 0.972885\tvalid_1's multi_logloss: 1.3603\n",
      "[8600]\ttraining's multi_logloss: 0.969248\tvalid_1's multi_logloss: 1.36006\n",
      "[8700]\ttraining's multi_logloss: 0.965566\tvalid_1's multi_logloss: 1.35979\n",
      "[8800]\ttraining's multi_logloss: 0.962\tvalid_1's multi_logloss: 1.35965\n",
      "[8900]\ttraining's multi_logloss: 0.95848\tvalid_1's multi_logloss: 1.35952\n",
      "[9000]\ttraining's multi_logloss: 0.954858\tvalid_1's multi_logloss: 1.35931\n",
      "[9100]\ttraining's multi_logloss: 0.951242\tvalid_1's multi_logloss: 1.35907\n",
      "[9200]\ttraining's multi_logloss: 0.947644\tvalid_1's multi_logloss: 1.35882\n",
      "[9300]\ttraining's multi_logloss: 0.944183\tvalid_1's multi_logloss: 1.35862\n",
      "[9400]\ttraining's multi_logloss: 0.940645\tvalid_1's multi_logloss: 1.35843\n",
      "[9500]\ttraining's multi_logloss: 0.937186\tvalid_1's multi_logloss: 1.35832\n",
      "[9600]\ttraining's multi_logloss: 0.933673\tvalid_1's multi_logloss: 1.3581\n",
      "[9700]\ttraining's multi_logloss: 0.930321\tvalid_1's multi_logloss: 1.35805\n",
      "[9800]\ttraining's multi_logloss: 0.926898\tvalid_1's multi_logloss: 1.35787\n",
      "[9900]\ttraining's multi_logloss: 0.923373\tvalid_1's multi_logloss: 1.35763\n",
      "[10000]\ttraining's multi_logloss: 0.919989\tvalid_1's multi_logloss: 1.35742\n",
      "[10100]\ttraining's multi_logloss: 0.916565\tvalid_1's multi_logloss: 1.35725\n",
      "[10200]\ttraining's multi_logloss: 0.913205\tvalid_1's multi_logloss: 1.357\n",
      "[10300]\ttraining's multi_logloss: 0.909884\tvalid_1's multi_logloss: 1.35682\n",
      "[10400]\ttraining's multi_logloss: 0.906492\tvalid_1's multi_logloss: 1.35662\n",
      "[10500]\ttraining's multi_logloss: 0.903154\tvalid_1's multi_logloss: 1.35648\n",
      "[10600]\ttraining's multi_logloss: 0.899909\tvalid_1's multi_logloss: 1.35643\n",
      "[10700]\ttraining's multi_logloss: 0.896722\tvalid_1's multi_logloss: 1.35634\n",
      "[10800]\ttraining's multi_logloss: 0.893527\tvalid_1's multi_logloss: 1.35624\n",
      "[10900]\ttraining's multi_logloss: 0.890329\tvalid_1's multi_logloss: 1.35616\n",
      "[11000]\ttraining's multi_logloss: 0.887121\tvalid_1's multi_logloss: 1.35615\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[11100]\ttraining's multi_logloss: 0.883993\tvalid_1's multi_logloss: 1.35606\n",
      "[11200]\ttraining's multi_logloss: 0.880897\tvalid_1's multi_logloss: 1.35605\n",
      "[11300]\ttraining's multi_logloss: 0.877782\tvalid_1's multi_logloss: 1.35597\n",
      "[11400]\ttraining's multi_logloss: 0.874686\tvalid_1's multi_logloss: 1.35596\n",
      "[11500]\ttraining's multi_logloss: 0.871599\tvalid_1's multi_logloss: 1.3559\n",
      "[11600]\ttraining's multi_logloss: 0.868433\tvalid_1's multi_logloss: 1.35582\n",
      "[11700]\ttraining's multi_logloss: 0.865309\tvalid_1's multi_logloss: 1.3557\n",
      "[11800]\ttraining's multi_logloss: 0.862323\tvalid_1's multi_logloss: 1.3557\n",
      "[11900]\ttraining's multi_logloss: 0.859349\tvalid_1's multi_logloss: 1.35567\n",
      "[12000]\ttraining's multi_logloss: 0.85648\tvalid_1's multi_logloss: 1.35575\n",
      "[12100]\ttraining's multi_logloss: 0.853534\tvalid_1's multi_logloss: 1.35577\n",
      "Early stopping, best iteration is:\n",
      "[11908]\ttraining's multi_logloss: 0.859119\tvalid_1's multi_logloss: 1.35566\n",
      "1136.274 Seconds to train lgb\n"
     ]
    }
   ],
   "source": [
    "# Wrap both splits for LightGBM. The training set is also the first watchlist\n",
    "# entry, so training and validation loss are reported side by side.\n",
    "d_train = lgb.Dataset(train_X, train_Y)\n",
    "# NOTE(review): the validation set is the held-out test split, so early stopping\n",
    "# is tuned on the same data that is scored afterwards; the reported test metrics\n",
    "# are likely optimistic.\n",
    "d_valid = lgb.Dataset(test_X, test_Y)\n",
    "watchlist = [d_train, d_valid]\n",
    "t = time.time()\n",
    "# Early stopping is configured once, through 'early_stopping_round' in\n",
    "# params_lgb. LightGBM already used that value instead of the duplicate\n",
    "# early_stopping_rounds keyword argument, so the argument is dropped here.\n",
    "clf = lgb.train(params_lgb, d_train, num_boost_round = 100000, valid_sets = watchlist, verbose_eval = 100)\n",
    "print(round(time.time() - t, 3), 'Seconds to train lgb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.47172572635013554"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy on the held-out split: fraction of rows whose arg-max class equals the label.\n",
    "(np.argmax(clf.predict(test_X), axis = 1) == test_Y).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Write the trained model to 'lgb-timestamp.model' in the working directory.\n",
    "clf.save_model('lgb-timestamp.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "             precision    recall  f1-score   support\n",
      "\n",
      "      anger       0.53      0.22      0.31     11587\n",
      "       fear       0.50      0.18      0.27      9504\n",
      "        joy       0.46      0.73      0.57     28074\n",
      "       love       0.30      0.08      0.13      6949\n",
      "    sadness       0.49      0.57      0.53     24293\n",
      "   surprise       0.26      0.09      0.13      2955\n",
      "\n",
      "avg / total       0.46      0.47      0.43     83362\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Per-class precision / recall / F1 on the held-out split. print() is kept\n",
    "# because classification_report returns one pre-formatted multi-line string.\n",
    "# `metrics` is imported in the imports cell at the top of the notebook.\n",
    "print(metrics.classification_report(\n",
    "    test_Y,\n",
    "    np.argmax(clf.predict(test_X), axis = 1),\n",
    "    target_names = trainset_data.target_names\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
